import torch
from data.load_cow import generate_cow_renders

class CowDataset(torch.utils.data.Dataset):
    """Map-style dataset that cycles over pre-rendered views of the cow mesh.

    All renders are produced once, at construction time, by
    ``generate_cow_renders``. ``__getitem__`` then serves them by index,
    wrapping around modulo the number of renders, so the reported epoch
    length can differ from the number of rendered views.

    Args:
        num_views: Number of views requested from ``generate_cow_renders``.
        epoch_len: Value returned by ``__len__``; independent of
            ``num_views`` because indices wrap around.
        azimuth_range: Forwarded unchanged to ``generate_cow_renders``.
            Defaults to 180, the value that used to be hard-coded here
            (units presumably degrees — TODO confirm against data.load_cow).
    """

    def __init__(self, num_views, epoch_len=1000, azimuth_range=180):
        (
            self.target_cameras,
            self.target_images,
            self.target_silhouettes,
        ) = generate_cow_renders(num_views=num_views, azimuth_range=azimuth_range)
        self.num_views = num_views
        self.epoch_len = epoch_len

    def __getitem__(self, i):
        """Return the camera, image and silhouette for sample ``i``.

        ``i`` is reduced modulo the number of rendered images, so any integer
        index (including ones >= the number of views, or negative ones)
        maps to a stored view. The same wrapped index is used for all three
        collections — NOTE(review): this assumes cameras, images and
        silhouettes have equal length; confirm generate_cow_renders
        guarantees it.
        """
        # Compute the wrapped index once and reuse it for every lookup.
        idx = i % len(self.target_images)
        return {
            'target_camera': self.target_cameras[idx],
            'target_image': self.target_images[idx],
            'target_silhouette': self.target_silhouettes[idx],
        }

    def __len__(self):
        """Return the configured epoch length (not the number of views)."""
        return self.epoch_len